[WIP] Super-resolution Benthic Object Detection¶
In our last notebook, we evaluated MBARI's Monterey Bay Benthic Object Detector on the external TrashCAN dataset. There, we found that model results were good, given the slight adaptations we had to make to compare against the new annotations. However, we also saw potential for increased model performance when applying some types of upscaling to the input images.
In this note, we will build a workflow to easily feed the inputs from the TrashCAN dataset through a super-resolution (SR) layer before feeding them into the MBARI model. We will then evaluate the performance of the model with and without the SR layer to see how much, if any, improvement we can achieve. It will also be important to measure computation time and memory usage to see if the tradeoff is worth it.
Before diving into the code, we will first discuss our motivations for applying a super-resolution layer to the input images, along with some fundamentally different architectures useful for SR. Specifically, we will look at GANs, Transformers, and more traditional CNN-based SR models.
In later notes, we can explore fine-tuning the set-up built here. This is important to keep in mind when making decisions about how to implement the super-resolution layer.
%load_ext autoreload
%autoreload 2
#%pip install -r ../requirements.txt
from fathomnet.models.yolov5 import YOLOv5Model
from IPython.display import display
from pathlib import Path
from PIL import Image
from pycocotools.coco import COCO
from typing import List
import json
import onnxruntime
import os
import numpy as np
root_dir = Path(os.getcwd().split("personal/")[0])
repo_dir = root_dir / "personal" / "ocean-species-identification"
# reuse some code from previous notebook, ported to src.data
os.chdir(repo_dir)
from src.data import *
Load¶
We will start by loading the TrashCAN dataset, the MBARI model, and the label map between the two. Aside from path building, each requires only a single line of code to load.
data_dir = root_dir / "data" / "TrashCAN"
benthic_model_weights_path = root_dir / "personal" / "models" / "fathomnet_benthic" / "mbari-mb-benthic-33k.pt"
benthic_model = YOLOv5Model(benthic_model_weights_path)
trashcan_data = COCO(data_dir / "dataset" / "material_version" / "instances_val_trashcan.json")
benthic2trashcan_ids = json.load(open(repo_dir / "data" / "benthic2trashcan_ids.json"))
Using cache found in /Users/per.morten.halvorsen@schibsted.com/.cache/torch/hub/ultralytics_yolov5_master YOLOv5 🚀 2024-2-24 Python-3.11.5 torch-2.2.1 CPU Fusing layers... Model summary: 476 layers, 91841704 parameters, 0 gradients Adding AutoShape...
loading annotations into memory... Done (t=0.12s) creating index... index created!
# benthic_model._model.eval()
Background¶
Motivation¶
The general idea with this setup is that if we can enhance some of the low-level, fine-grained patterns in the images, the larger picture may be easier to interpret for the model.
The Benthic Object Detector is trained from a YOLOv5 backbone. This architecture contains a deep stack of convolutional layers, making it good at picking up low-level features, such as edges and textures, in its first few layers. The deeper CNN layers are then able to piece these features together to produce high-level predictions about the full objects in the image, as done in object detection. A super-resolution layer's job would then be to enhance the low-level features to make them more clear for the Benthic model to better work its math-magic on.
Convolutions¶
Briefly explained, convolutional layers apply a filter to one sliding window of the input image at a time, before pooling allows the strongest signal from each "glimpse" to pass through. The more filters a model has, the more flexibility it has to "focus" on different patterns or features in the inputs. This is essential for upscaling, since the model needs to learn to correctly predict the values of the new pixels it adds to the image.
Inputs are often split into a number of channels, which can be thought of as "slices" of the input image, like RGB layers. A convolutional filter will then also have the same number of channels, with tunable parameters for each one.
In the animation below, we can see eight 3x3 convolutional filters applied to a 7x6 input originally split into 8 channels.
Check out this great visualization of convolutions from Animated AI on Youtube
The elementwise products between the filter and each window are summed (and optionally pooled) to produce a single value in the output. This means output sizes vary greatly depending on the size of the input, filter, stride, and dilation. As you can see, the more filters you have, the more layers of features you get in your outputs.
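To make the shape bookkeeping concrete, here is a minimal single-channel sketch in plain NumPy (an illustration, not the YOLOv5 implementation): a "valid" sliding-window convolution together with the standard output-size formula.

```python
import numpy as np

def conv_output_size(n, k, stride=1, padding=0, dilation=1):
    """Output size along one spatial dimension of a convolution."""
    return (n + 2 * padding - dilation * (k - 1) - 1) // stride + 1

def conv2d(x, w):
    """Naive single-channel 'valid' convolution (cross-correlation)."""
    h = conv_output_size(x.shape[0], w.shape[0])
    wd = conv_output_size(x.shape[1], w.shape[1])
    out = np.zeros((h, wd))
    for i in range(h):
        for j in range(wd):
            # one "glimpse": elementwise product of window and filter, summed
            out[i, j] = (x[i:i + w.shape[0], j:j + w.shape[1]] * w).sum()
    return out

x = np.arange(42, dtype=float).reshape(7, 6)  # 7x6 single-channel input
w = np.ones((3, 3)) / 9                       # 3x3 averaging filter
print(conv2d(x, w).shape)  # (5, 4)
```

With stride, padding, or dilation changed, the same formula predicts how the output shrinks or grows; stacking one such output per filter is what produces the multi-channel feature maps described above.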
Super-resolution models¶
In our first note, we chose the ABPN-based model as our initial upsampler due to its light-weight architecture and ease of use. In this note, we will consider a few hand-picked models to measure any performance differences between architectures. There are a few different base components to consider when choosing a super-resolution model. Some emphasize context, others emphasize flexibility. The base architectural components we will consider here are: GANs, CNNs, and attention. The papers for the following architectures can be found in the research/ folder.
ABPN: Anchor-based Plain Net for Mobile Image Super-Resolution
This is the model we used in our first notebook, via this playground. As an INT8-quantized model, the SR Mobile PyTorch model aims to be as small yet efficient as possible, in order to run on mobile devices. It can "restore twenty-seven 1080P images (x3 scale) per second, maintaining good perceptual quality as well." In other words, it is fast and computationally cheap.
The ABPN architecture makes use of convolutions, residual connections, and a pixel shuffle layer.
A residual connection is a way to "skip" a layer in a neural network, sending an untouched version of the input to a later layer, to preserve the original signal. In super resolution, this helps the model maintain important features of the input when upscaling the image. Different architectures make use of residual connections rather differently, sometimes in very complex manners. The ABPN model, however, uses a simple residual connection, as seen in the image below.
Notice one of the last layers called "depth to space". This is a pixel shuffle layer, which is a way to upscale the input by rearranging the pixels in the input. This component leverages the features extracted by the upstream CNN layers to predict the values of the final pixels in the output image. Animated AI has another great visualization to explain this concept.
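Both ideas can be sketched in a few lines of NumPy. This is a toy illustration, not the actual ABPN code: the skip connection adds the untouched input back to a stand-in "conv stack", and pixel shuffle rearranges channel depth into spatial resolution (matching the semantics of torch.nn.PixelShuffle).

```python
import numpy as np

def pixel_shuffle(x, r):
    """Depth-to-space: rearrange a (C*r*r, H, W) array into (C, H*r, W*r),
    trading channel depth for spatial resolution."""
    c, h, w = x.shape[0] // (r * r), x.shape[1], x.shape[2]
    x = x.reshape(c, r, r, h, w)
    x = x.transpose(0, 3, 1, 4, 2)  # C, H, r, W, r
    return x.reshape(c, h * r, w * r)

# Residual connection: add the untouched input back to the block's output.
features = np.random.rand(48, 4, 5)           # 48 = 3 RGB channels * 4^2 for a x4 upscale
refined = features + 0.1 * np.tanh(features)  # toy "conv stack" plus skip connection
print(pixel_shuffle(refined, 4).shape)        # (3, 16, 20)
```

Note that pixel shuffle adds no parameters of its own; the upstream convolutions must learn to write the future sub-pixel values into the extra channels.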
The model we will use here is a PyTorch adaptation of the original model, which was written in TensorFlow. According to the SR Mobile PyTorch GitHub, the architecture was ported as is, with minimal changes. We opted for the PyTorch version of this architecture, since it was easier to use out-of-the-box, and because PyTorch is subjectively easier to work with.
ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks
Published as a state-of-the-art super resolution model in 2018, the ESRGAN is much more complex than the previous model.
It was trained using a generator-discriminator setup similar to SRGAN, with a 19-layer VGG network used for the perceptual loss.
Testing on this model will allow us to observe the difference larger models can make in performance.
Two major changes from the SRGAN model were made to the generator in ESRGAN:
1. They removed batch normalization inside the dense blocks.
The reasoning behind this was that batch normalization introduces artifacts during evaluation since the model is using an estimated mean and variance for the normalization from the training. This becomes a problem in data sets where training and test sets can vary quite a lot.
Additionally, empirical observations have shown that removing batch-normalization increases generalization and performance, while lowering computational cost.
2. They introduced Residual-in-Residual Dense Block (RRDB).
Connecting all layers through residual connectors is expected to boost performance, as it allows the model to learn more complex features.
This is the opposite extreme of the ABPN model, which uses a single residual connection.
There has been other work on these "multilevel residual networks" (Zhang et al. 2017) showing improved performance in other tasks when using such residual blocks. It is, however, important to keep in mind that this added complexity will also increase the required compute resources. If we were to fit our pipeline into the head of an underwater robot, we would need to consider the tradeoff between performance and computational cost.
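The residual-in-residual idea can be sketched in a few lines of NumPy. This is a deliberately simplified illustration, not the ESRGAN code: `dense_block` here is just a stand-in nonlinearity (the real RRDB uses densely connected conv layers), and `beta` is the residual scaling factor from the paper.

```python
import numpy as np

def dense_block(x):
    """Stand-in for a densely connected conv block (toy simplification)."""
    return np.tanh(x)

def rrdb(x, n_blocks=3, beta=0.2):
    """Residual-in-Residual Dense Block: a scaled residual around every
    dense block, plus an outer residual around the whole group."""
    out = x
    for _ in range(n_blocks):
        out = out + beta * dense_block(out)  # inner residual per dense block
    return x + beta * (out - x)              # outer residual around the group

x = np.random.rand(3, 8, 8)
print(rrdb(x).shape)  # (3, 8, 8)
```

Because every block's input is carried forward unscaled, gradients and the original signal both have a short path through the network, which is the motivation for stacking residuals at multiple levels.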
Some other key improvements in that architecture include:
- Relativistic discriminator: predicts the probability that a real image is relatively more realistic than a generated one, rather than a binary real/fake decision.
- Not sure if they removed the batch normalization from the discriminator as well; this can be looked into further later.
- Refined perceptual loss: constraining (applying) the loss on features before the activation functions, to preserve more information.
- Network interpolation: interpolating the weights of two models, the fine-tuned GAN and a peak signal-to-noise ratio (PSNR) oriented model, via a tunable parameter. This makes it easy to trade off output quality without retraining the model.
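Network interpolation is simple enough to sketch directly; the blending below is the idea from the ESRGAN paper, though the parameter names are made up for illustration.

```python
import numpy as np

def interpolate_networks(psnr_weights, gan_weights, alpha):
    """Blend the parameters of a PSNR-oriented model and a fine-tuned GAN:
    alpha=0 gives the PSNR model, alpha=1 the GAN, with no retraining."""
    return {
        name: (1 - alpha) * psnr_weights[name] + alpha * gan_weights[name]
        for name in psnr_weights
    }

# toy one-layer "models" (hypothetical parameter names)
psnr = {"conv1.weight": np.zeros(4)}
gan = {"conv1.weight": np.ones(4)}
blended = interpolate_networks(psnr, gan, alpha=0.25)
print(blended["conv1.weight"])  # [0.25 0.25 0.25 0.25]
```

Sliding alpha lets you pick a point between sharper-but-artifact-prone GAN outputs and smoother PSNR-optimized outputs at inference time.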
HAT: Hybrid Attention Transformer for Image Restoration
Recently, Transformer models have been shown to perform well in computer vision tasks, following their success in NLP.
The HAT model is a hybrid model that combines both attention and convolutions.
It was published in late 2023, and currently holds the state-of-the-art title for most of the classic super-resolution benchmark datasets.
It builds off of other attention-based SR models such as the Swin Transformer and RCAN.
A transformer is a neural-network component built around cross- or self-attention.
Attention is a matrix-multiplication operation that adds contextual information to each token or pixel.
The image below shows what this looks like in the context of language processing.
By Peltarion on YouTube
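A minimal NumPy sketch of scaled dot-product self-attention makes the "matrix multiplication that adds context" concrete; the random projection matrices stand in for learned weights.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def self_attention(x, wq, wk, wv):
    """Scaled dot-product self-attention: every token/pixel receives a
    weighted mix of all value vectors, weighted by query-key similarity."""
    q, k, v = x @ wq, x @ wk, x @ wv
    scores = q @ k.T / np.sqrt(k.shape[-1])
    return softmax(scores) @ v

rng = np.random.default_rng(0)
tokens = rng.normal(size=(5, 8))                 # 5 tokens/pixels, 8 features each
wq, wk, wv = (rng.normal(size=(8, 8)) for _ in range(3))
print(self_attention(tokens, wq, wk, wv).shape)  # (5, 8)
```

In vision transformers the "tokens" are image patches or pixels within a window, so each output position mixes in information from the rest of its window.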
Building off the SwinIR model, the HAT model uses residual-in-residual connections to propagate information through the network. The components can be broken into three main parts:
- Shallow feature extraction: a few layers of convolutions to extract low-level features from the input.
- Deep feature extraction: a few layers of transformer blocks to extract high-level features from the input.
- Image reconstruction: a few layers of convolutions to upscale the input.
I'll save the details for the paper, but some of the key points are:
- Each residual hybrid attention group (RHAG) contains a cross-attention block, a self-attention block, and a channel attention block.
- RHAGs also utilize a convolution as their final sub-layer, to help the transformer "get a better visual representation" of the input.
- Pixel shuffling is used in the upscale module, similar to the ABPN model.
- They introduce overlapping cross attention, meaning the windows of the attention blocks overlap. This differs from the original image transformer models, which used non-overlapping, same-sized windows for the queries, keys, and values.
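The overlapping-window idea can be illustrated with a small NumPy helper (a hypothetical sketch, not the HAT implementation): with zero overlap we recover the non-overlapping partition of the original image transformers, and a positive overlap lets neighboring windows share pixels.

```python
import numpy as np

def windows(x, size, overlap=0):
    """Slide a size x size window over x, with the given pixel overlap
    between neighboring windows (overlap=0 is the non-overlapping case)."""
    step = size - overlap
    out = []
    for i in range(0, x.shape[0] - size + 1, step):
        for j in range(0, x.shape[1] - size + 1, step):
            out.append(x[i:i + size, j:j + size])
    return np.stack(out)

x = np.arange(64, dtype=float).reshape(8, 8)
print(windows(x, 4).shape)             # (4, 4, 4): four non-overlapping 4x4 windows
print(windows(x, 4, overlap=2).shape)  # (9, 4, 4): nine overlapping windows
```

Overlap increases the number of windows (and thus compute), but gives the attention blocks context that crosses window borders.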
For more details, read the white paper
The HAT model is yet another deep and complex model, with a large number of layers and parameters. It is expected to perform well, but at the cost of computational resources.
Implementation¶
Now that we have a rough understanding of the models we will look into, let's load them into our environment and feed them our TrashCAN data.
Steps¶
For each model, we will execute the following steps:
- (single) feed image from COCO through super resolution models
- (single) feed outputs through MBARI model & show detections
- (full) wrap dataflow into a pipeline
- (full) run classification on a given category of TrashCAN images
Single Image SR (SISR)¶
ABPN¶
# abpn_model_path = root_dir / "personal" / "models" / "sr_mobile_python" / "models_modelx2.ort"
abpn_model_path = root_dir / "personal" / "models" / "sr_mobile_python" / "models_modelx4.ort"
# get only starfish images using src.data.image_from_category
starfish_images = images_per_category("animal_starfish", trashcan_data, data_dir / "dataset" / "material_version" / "val")
example_image_path = starfish_images[3]
example_image = Image.open(example_image_path)
print(np.array(example_image).shape)
display(example_image)
(270, 480, 3)
The following methods were adapted from the sr_mobile_python's inference module.
import numpy as np
import cv2
import onnxruntime
from glob import glob
import os
from tqdm.auto import tqdm
class ABPN:
def __init__(self, model_path: str, store:bool=True):
self.model_path = model_path
self.saved_imgs = {}
self.store = store
def pre_process(self, img: np.array) -> np.array:
# H, W, C -> C, H, W
img = np.transpose(img[:, :, 0:3], (2, 0, 1))
# C, H, W -> 1, C, H, W
img = np.expand_dims(img, axis=0).astype(np.float32)
return img
def post_process(self, img: np.array) -> np.array:
# 1, C, H, W -> C, H, W
img = np.squeeze(img)
# C, H, W -> H, W, C
img = np.transpose(img, (1, 2, 0))
return img
def save(self, img: np.array, save_name: str) -> None:
# cv2.imwrite(save_name, img)
if self.store:
self.saved_imgs[save_name] = img
def inference(self, img_array: np.array) -> np.array:
# unsure about ability to train an ONNX model from a Mac
ort_session = onnxruntime.InferenceSession(self.model_path)
ort_inputs = {ort_session.get_inputs()[0].name: img_array}
ort_outs = ort_session.run(None, ort_inputs)
return ort_outs[0]
def upscale(self, image_paths: List[str]):
outputs = []
for image_path in tqdm(image_paths):
img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
# filename = os.path.basename(image_path)
if img.ndim == 2:
img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
if img.shape[2] == 4:
alpha = img[:, :, 3] # GRAY
alpha = cv2.cvtColor(alpha, cv2.COLOR_GRAY2BGR) # BGR
alpha_output = self.post_process(
self.inference(self.pre_process(alpha))
) # BGR
alpha_output = cv2.cvtColor(alpha_output, cv2.COLOR_BGR2GRAY) # GRAY
img = img[:, :, 0:3] # BGR
image_output = self.post_process(
self.inference(self.pre_process(img))
) # BGR
output_img = cv2.cvtColor(image_output, cv2.COLOR_BGR2BGRA) # BGRA
output_img[:, :, 3] = alpha_output
self.save(output_img, Path(image_path).stem)
elif img.shape[2] == 3:
image_output = self.post_process(
self.inference(self.pre_process(img))
) # BGR
self.save(image_output, Path(image_path).stem)
outputs += [image_output.astype('uint8')]
return outputs
abpn_model = ABPN(abpn_model_path)
example_upscaled = abpn_model.upscale([str(example_image_path)])[0]
print(example_upscaled.shape)
Image.fromarray(example_upscaled)
0%| | 0/1 [00:00<?, ?it/s]
(1080, 1920, 3)
# check the scale of the super-resolution image
x_scale = example_upscaled.shape[1] / example_image.size[0]
y_scale = example_upscaled.shape[0] / example_image.size[1]
(x_scale, y_scale)
(4.0, 4.0)
ESRGAN¶
We can use the implementation of Real-ESRGAN by akhaliq on HuggingFace.
import cv2
import glob
import os
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
import json
import numpy as np
import onnxruntime
import PIL.Image
import tqdm
class ESRGAN:
def __init__(self,
input_path='inputs',
model_name='RealESRGAN_x4plus',
output_path='results',
outscale=4,
suffix='out',
tile=0,
tile_pad=10,
pre_pad=0,
face_enhance=False,
half=False,
alpha_upsampler='realesrgan',
ext='auto'
):
self.args = None
self.upsampler = None
self.face_enhancer = None
self.input_path = input_path
self.model_name = model_name
self.output_path = output_path
self.outscale = outscale
self.suffix = suffix
self.tile = tile
self.tile_pad = tile_pad
self.pre_pad = pre_pad
self.face_enhance = face_enhance
self.half = half
self.alpha_upsampler = alpha_upsampler
self.ext = ext
self.model = None
self.netscale = None
self.model_path = None
def main(self):
"""Inference demo for Real-ESRGAN.
"""
# determine models according to model names
model_name = self.model_name.split('.')[0]
if model_name in ['RealESRGAN_x4plus', 'RealESRNet_x4plus']: # x4 RRDBNet model
self.model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
self.netscale = 4
elif model_name in ['RealESRGAN_x4plus_anime_6B']: # x4 RRDBNet model with 6 blocks
self.model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
self.netscale = 4
elif model_name in ['RealESRGAN_x2plus']: # x2 RRDBNet model
self.model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=2)
self.netscale = 2
elif model_name in [
'RealESRGANv2-anime-xsx2', 'RealESRGANv2-animevideo-xsx2-nousm', 'RealESRGANv2-animevideo-xsx2'
]: # x2 VGG-style model (XS size)
self.model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=2, act_type='prelu')
self.netscale = 2
elif model_name in [
'RealESRGANv2-anime-xsx4', 'RealESRGANv2-animevideo-xsx4-nousm', 'RealESRGANv2-animevideo-xsx4'
]: # x4 VGG-style model (XS size)
self.model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=16, upscale=4, act_type='prelu')
self.netscale = 4
# determine model paths
self.model_path = os.path.join('.', model_name + '.pth')
if not os.path.isfile(self.model_path):
raise ValueError(f'Model {model_name} does not exist.')
# restorer
self.upsampler = RealESRGANer(
scale=self.netscale,
model_path=self.model_path,
model=self.model,
tile=self.tile,
tile_pad=self.tile_pad,
pre_pad=self.pre_pad,
half=self.half
)
# def enhance_images(self):
# if self.face_enhance: # Use GFPGAN for face enhancement
# from gfpgan import GFPGANer
# self.face_enhancer = GFPGANer(
# model_path='https://github.com/TencentARC/GFPGAN/releases/download/v0.2.0/GFPGANCleanv1-NoCE-C2.pth',
# upscale=self.outscale,
# arch='clean',
# channel_multiplier=2,
# bg_upsampler=self.upsampler)
# os.makedirs(self.output_path, exist_ok=True)
# if os.path.isfile(self.input_path):
# paths = [self.input_path]
# else:
# paths = sorted(glob.glob(os.path.join(self.input_path, '*')))
# for idx, path in enumerate(paths):
# imgname, extension = os.path.splitext(os.path.basename(path))
# print('Testing', idx, imgname)
# img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
# if len(img.shape) == 3 and img.shape[2] == 4:
# img_mode = 'RGBA'
# else:
# img_mode = None
# try:
# if self.face_enhance:
# _, _, output = self.face_enhancer.enhance(img, has_aligned=False, only_center_face=False, paste_back=True)
# else:
# output, _ = self.upsampler.enhance(img, outscale=self.outscale)
# except RuntimeError as error:
# print('Error', error)
# print('If you encounter CUDA out of memory, try to set --tile with a smaller number.')
# else:
# if self.args.ext == 'auto':
# extension = extension[1:]
# else:
# extension = self.args.ext
# if img_mode == 'RGBA': # RGBA images should be saved in png format
# extension = 'png'
# save_path = os.path.join(self.args.output, f'{imgname}_{self.args.suffix}.{extension}')
# cv2.imwrite(save_path, output)
if __name__ == '__main__':
esrgan = ESRGAN()
esrgan.main()
# esrgan.enhance_images()
HAT: Hybrid Attention Transformer for Image Restoration¶
SISR to Benthic Object Detector¶
example_detections = benthic_model._model(example_image)
upscaled_detections = benthic_model._model(example_upscaled)
example_detections.show()
upscaled_detections.show()
Here we see what we are trying to achieve with this super-resolution layer.
TODO: Look into why the color is off. The hue seems a bit redder in the upscaled version.
The first few examples showed little to no improvement, so we went with the fourth image to show when such a pipeline might be useful. This is a form of cherry-picking our results, but mainly for visualization purposes. The final evaluation will compare the methods fairly, without any hand-selection of inputs.
Build prediction pipeline¶
abpn_model_path
PosixPath('/Users/per.morten.halvorsen@schibsted.com/personal/models/sr_mobile_python/models_modelx4.ort')
from fathomnet.models.yolov5 import YOLOv5Model
class YOLOv5ModelWithUpscale(YOLOv5Model):
def __init__(self, detection_model_path: str, upscale_model_path: str = None):
super().__init__(detection_model_path)
self.upscale_model_path = upscale_model_path
def forward(self, X: List[str]):
if self.upscale_model_path:
X = upscale(X, self.upscale_model_path)
return self._model(X)
upscale_model = YOLOv5ModelWithUpscale(benthic_model_weights_path, abpn_model_path)
upscaled_detections = upscale_model.forward([str(example_image_path)]) # upscale expects a list of image paths
upscaled_detections.show()
Using cache found in /Users/per.morten.halvorsen@schibsted.com/.cache/torch/hub/ultralytics_yolov5_master YOLOv5 🚀 2024-2-24 Python-3.11.5 torch-2.2.1 CPU Fusing layers... Model summary: 476 layers, 91841704 parameters, 0 gradients Adding AutoShape...
0%| | 0/1 [00:00<?, ?it/s]
I'll add a somewhat hacky fix here to make sure the call methods of the two models match. This will help standardize our evaluation setup later on.
def forward(self, X: List[str]):
return self._model(X)
benthic_model.forward = forward.__get__(benthic_model)
example_detections = benthic_model.forward([str(example_image_path)])
example_detections.show()
Full category classifications¶
As a sanity check, let us see if we can produce predictions for a larger number of images. Here, we'll use the "Eel" class, since that category seemed to have the fewest images, as observed in the previous notebook.
N = 5
# raw_starfish_detections = benthic_model.forward(starfish_images[:N])
# upscaled_starfish_detections = upscale_model.forward(starfish_images[:N])
# raw_starfish_detections.show()
# upscaled_starfish_detections.show()
0%| | 0/5 [00:00<?, ?it/s]
Great! Now we can easily feed the TrashCAN dataset through the super-resolution model and then through the MBARI model. Let's get the evaluation methods developed in the last notebook and use them to compare our models.
Evaluation¶
Our evaluation will contain three main steps:
- Import the methods from our previous notebook
- Evaluate both the benthic_model and the upscale_model
- Compare the results of the two models
We start by importing the methods from the previous notebook. These methods were ported to stand-alone code, for cleaner imports.
from src.evaluation import *
# rebuild some needed params locally
trashcan_ids = {
row["supercategory"]: id
for id, row in trashcan_data.cats.items()
}
# find trash index
trash_idx = list(benthic_model._model.names.values()).index("trash")
print(benthic_model._model.names[trash_idx])
# find trash labels
trashcan_trash_labels = {
id: name
for name, id in trashcan_ids.items()
if name.startswith("trash")
}
trashcan_trash_labels
trash
{9: 'trash_etc',
10: 'trash_fabric',
11: 'trash_fishing_gear',
12: 'trash_metal',
13: 'trash_paper',
14: 'trash_plastic',
15: 'trash_rubber',
16: 'trash_wood'}
# replace str keys with ints
benthic2trashcan_ids = {
int(key): value
for key, value in benthic2trashcan_ids.items()
}
Run evaluation on both models¶
raw_starfish_metrics = evaluate_model(
category="animal_starfish",
data=trashcan_data,
model=benthic_model,
id_map=benthic2trashcan_ids,
# verbose=2,
# N=5,
one_idx=trash_idx,
many_idx=trashcan_trash_labels,
exclude_ids=[trashcan_ids["rov"], trashcan_ids["plant"]],
path_prefix=data_dir / "dataset" / "material_version" / "val"
)
raw_starfish_metrics
Precision: 0.39534882801514354 Recall: 0.08415841542495835 Average IoU: tensor(0.31285)
{'precision': 0.39534882801514354,
'recall': 0.08415841542495835,
'iou': tensor(0.31285),
'time': 41.07221722602844}
upscale_starfish_metrics = evaluate_model(
category="animal_starfish",
data=trashcan_data,
model=upscale_model,
id_map=benthic2trashcan_ids,
# verbose=2,
# N=5,
one_idx=trash_idx,
many_idx=trashcan_trash_labels,
exclude_ids=[trashcan_ids["rov"], trashcan_ids["plant"]],
path_prefix=data_dir / "dataset" / "material_version" / "val",
x_scale=x_scale,
y_scale=y_scale
)
upscale_starfish_metrics
0%| | 0/46 [00:00<?, ?it/s]
Precision: 0.20312499682617194 Recall: 0.06435643532496814 Average IoU: tensor(0.15016)
{'precision': 0.20312499682617194,
'recall': 0.06435643532496814,
'iou': tensor(0.15016),
'time': 45.43256592750549}
Metrics for all categories¶
def evaluate_both_models(category, N=-1, verbose=False):
raw_metrics = evaluate_model(
category=category,
data=trashcan_data,
model=benthic_model,
id_map=benthic2trashcan_ids,
verbose=verbose,
N=N,
one_idx=trash_idx,
many_idx=trashcan_trash_labels,
exclude_ids=[trashcan_ids["rov"], trashcan_ids["plant"]],
path_prefix=data_dir / "dataset" / "material_version" / "val"
)
upscale_metrics = evaluate_model(
category=category,
data=trashcan_data,
model=upscale_model,
id_map=benthic2trashcan_ids,
verbose=verbose,
N=N,
one_idx=trash_idx,
many_idx=trashcan_trash_labels,
exclude_ids=[trashcan_ids["rov"], trashcan_ids["plant"]],
path_prefix=data_dir / "dataset" / "material_version" / "val",
x_scale=x_scale,
y_scale=y_scale
)
return raw_metrics, upscale_metrics
raw_fish_metrics, upscale_fish_metrics = evaluate_both_models("animal_fish")
print(raw_fish_metrics)
print(upscale_fish_metrics)
0%| | 0/100 [00:00<?, ?it/s]
{'precision': 0.4166666608796297, 'recall': 0.11406844063091848, 'iou': tensor(0.32055), 'time': 81.14837098121643}
{'precision': 0.30769229585798863, 'recall': 0.030418250834911596, 'iou': tensor(0.21463), 'time': 83.67938709259033}
raw_eel_metrics, upscale_eel_metrics = evaluate_both_models("animal_eel")
print(raw_eel_metrics)
print(upscale_eel_metrics)
0%| | 0/73 [00:00<?, ?it/s]
{'precision': 0.19565216965973545, 'recall': 0.05142857113469388, 'iou': tensor(0.14774), 'time': 48.71544289588928}
{'precision': 0.0, 'recall': 0.0, 'iou': 0.0, 'time': 53.536354064941406}
raw_crab_metrics, upscale_crab_metrics = evaluate_both_models("animal_crab")
print(raw_crab_metrics)
print(upscale_crab_metrics)
0%| | 0/39 [00:00<?, ?it/s]
{'precision': 0.07692307573964499, 'recall': 0.03246753225670434, 'iou': tensor(0.06751), 'time': 36.593260049819946}
{'precision': 0.006535947669699688, 'recall': 0.006493506451340867, 'iou': tensor(0.00869), 'time': 38.949223041534424}
raw_trash_metrics, upscale_trash_metrics = evaluate_both_models("trash_plastic")
print(raw_trash_metrics)
print(upscale_trash_metrics)
0%| | 0/340 [00:00<?, ?it/s]
Conclusion¶
Wrap things up and make a plan for next steps.
Idea:
- Fine-tuning
- add extra final output layer to MBARI model mapping 691 outputs to 17 TrashCAN labels
- select images from FathomNet that have annotations for the TrashCAN labels
- fine-tune the model on the FathomNet images
- Evaluate on TrashCAN dataset
- Deeper comparative analysis
- compare the performance of the super-resolution models on the FathomNet dataset
- Manually observe annotations from TrashCAN and FathomNet to empirically evaluate quality